from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from collections import Counter
import torchvision
import numpy as np
import torch


def _load_cifar(dataset, transform):
    """Load the CIFAR train/test splits selected by ``dataset``.

    Args:
        dataset: ``"cifar10"`` or ``"cifar100"`` (case-insensitive).
        transform: Transform applied to every image of both splits.

    Returns:
        ``(train_dataset, test_dataset, num_classes)``. Data is stored under
        ``./data`` and downloaded if missing.

    Raises:
        ValueError: If ``dataset`` names neither CIFAR-10 nor CIFAR-100.
    """
    name = dataset.lower()
    if name == "cifar10":
        dataset_cls, num_classes = torchvision.datasets.CIFAR10, 10
    elif name == "cifar100":
        dataset_cls, num_classes = torchvision.datasets.CIFAR100, 100
    else:
        # Previously any unknown name silently fell through to CIFAR-100.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'cifar10' or 'cifar100'."
        )
    train_dataset = dataset_cls(
        root="./data", train=True, download=True, transform=transform
    )
    test_dataset = dataset_cls(
        root="./data", train=False, download=True, transform=transform
    )
    return train_dataset, test_dataset, num_classes


def _split_by_proportions(idx_batch, idx_k, proportions, cap):
    """Append a proportional slice of ``idx_k`` to each client's index list.

    Clients whose list already holds ``cap`` or more indices get weight 0; the
    remaining weights are renormalised and turned into cut points over
    ``idx_k``. Returns the new list of per-client index lists.
    """
    weights = np.array(
        [p * (len(idx_j) < cap) for p, idx_j in zip(proportions, idx_batch)]
    )
    weights = weights / weights.sum()
    cut_points = (np.cumsum(weights) * len(idx_k)).astype(int)[:-1]
    return [
        idx_j + idx.tolist()
        for idx_j, idx in zip(idx_batch, np.split(idx_k, cut_points))
    ]


def _dirichlet_partition(
    y_train, y_test, num_clients, num_classes, alpha, min_require_size=10
):
    """Partition train and test indices across clients with Dirichlet label skew.

    For each class a single ``Dirichlet(alpha, ..., alpha)`` draw is shared by
    the train and test splits, so a client's test shard follows the same label
    skew as its train shard. Clients already holding at least an equal share
    (``len(split) / num_clients``) of a split receive nothing more from that
    split for the current class. The whole draw is repeated until every client
    holds at least ``min_require_size`` training indices.

    Uses and advances NumPy's *global* RNG.

    Returns:
        ``(train_idx, test_idx)``: two lists of length ``num_clients``, each
        entry a shuffled list of integer indices into the respective split.

    Raises:
        ValueError: If ``num_clients < 1`` or the minimum size is unattainable.
    """
    n_train = len(y_train)
    n_test = len(y_test)
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}.")
    if num_clients * min_require_size > n_train:
        # The retry loop below could never terminate in this case.
        raise ValueError(
            f"Cannot give each of {num_clients} clients at least "
            f"{min_require_size} of {n_train} training samples."
        )

    # Class membership never changes between retries; compute it once.
    train_by_class = [np.where(y_train == k)[0] for k in range(num_classes)]
    test_by_class = [np.where(y_test == k)[0] for k in range(num_classes)]

    min_size = 0
    while min_size < min_require_size:
        idx_batch = [[] for _ in range(num_clients)]
        idx_batch_test = [[] for _ in range(num_clients)]
        for k in range(num_classes):
            # Copies keep the cached arrays sorted, matching a fresh np.where.
            idx_k = train_by_class[k].copy()
            idx_k_test = test_by_class[k].copy()
            np.random.shuffle(idx_k)
            np.random.shuffle(idx_k_test)
            proportions = np.random.dirichlet(np.repeat(alpha, num_clients))
            idx_batch = _split_by_proportions(
                idx_batch, idx_k, proportions, n_train / num_clients
            )
            idx_batch_test = _split_by_proportions(
                idx_batch_test, idx_k_test, proportions, n_test / num_clients
            )
        min_size = min(len(idx_j) for idx_j in idx_batch)

    for j in range(num_clients):
        np.random.shuffle(idx_batch[j])
        np.random.shuffle(idx_batch_test[j])
    return idx_batch, idx_batch_test


def _class_distribution(labels, num_classes):
    """Return ``{class: fraction of labels equal to class}`` for every class.

    All fractions are 0 when ``labels`` is empty.
    """
    total = len(labels)
    counts = Counter(labels)
    return {
        cls: (counts.get(cls, 0) / total if total > 0 else 0)
        for cls in range(num_classes)
    }


def load_and_partition_data(
    num_clients, alpha, batch_size, frac, rand_seed=42, dataset="cifar10"
):
    """Build per-client CIFAR train/validation loaders with Dirichlet label skew.

    Args:
        num_clients: Number of clients to partition the data across.
        alpha: Concentration of the symmetric Dirichlet drawn per class;
            smaller values give more uneven per-client label splits.
        batch_size: Batch size for every returned DataLoader.
        frac: Fraction of each client's training shard to keep (values above
            1 are treated as 1). Validation keeps ``min(2 * frac, 1)`` of the
            client's test shard.
        rand_seed: Seed for the partition and all loader shuffling.
        dataset: ``"cifar10"`` or ``"cifar100"`` (case-insensitive).

    Returns:
        ``(client_train_loaders, client_val_loaders, test_loader,
        client_class_distributions)``: per-client train and validation
        DataLoaders, a DataLoader over the full test split, and per-client
        dicts mapping class id to its fraction of that client's train samples.

    Raises:
        ValueError: For an unknown ``dataset``, a negative ``frac``,
            ``num_clients < 1``, or more clients than the 10-samples-per-client
            minimum allows.

    Side effects:
        Seeds torch's and NumPy's global RNGs (NumPy is re-seeded once per
        client), downloads data to ``./data`` if missing, and prints each
        client's class distribution.

    NOTE(review): validation shards are subsets of the CIFAR *test* split, the
    same data ``test_loader`` iterates, so validation and test are not
    disjoint.
    """
    if frac < 0:
        raise ValueError(f"frac must be non-negative, got {frac}.")

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    torch.manual_seed(rand_seed)
    np.random.seed(rand_seed)

    full_dataset, test_dataset, num_classes = _load_cifar(dataset, transform)
    y_train = np.array(full_dataset.targets)
    y_test = np.array(test_dataset.targets)

    net_dataidx_map, net_dataidx_map_test = _dirichlet_partition(
        y_train, y_test, num_clients, num_classes, alpha
    )

    # frac > 1 used to make np.random.choice(..., replace=False) raise; clamp
    # it the same way the validation fraction has always been clamped.
    train_frac = min(frac, 1.0)
    val_frac = min(2 * frac, 1.0)

    client_train_loaders = []
    client_val_loaders = []
    client_class_distributions = []

    for i in range(num_clients):
        # Per-client reseed keeps each client's subsample independent of the
        # others' shard sizes.
        np.random.seed(rand_seed + i)
        client_idx = net_dataidx_map[i]
        client_idx_test = net_dataidx_map_test[i]

        frac_indices = np.random.choice(
            len(client_idx), int(train_frac * len(client_idx)), replace=False
        )
        train_indices = [client_idx[j] for j in frac_indices]

        frac_indices_test = np.random.choice(
            len(client_idx_test),
            int(val_frac * len(client_idx_test)),
            replace=False,
        )
        val_indices = [client_idx_test[j] for j in frac_indices_test]

        client_class_distributions.append(
            _class_distribution(
                [y_train[idx] for idx in train_indices], num_classes
            )
        )

        g_train = torch.Generator().manual_seed(rand_seed + i)
        g_val = torch.Generator().manual_seed(rand_seed + i + num_clients)

        client_train_loaders.append(
            DataLoader(
                Subset(full_dataset, train_indices),
                batch_size=batch_size,
                shuffle=True,
                generator=g_train,
                # Full batches only for training; a client with fewer than
                # batch_size kept samples therefore yields no batches.
                drop_last=True,
            )
        )
        client_val_loaders.append(
            DataLoader(
                Subset(test_dataset, val_indices),
                batch_size=batch_size,
                shuffle=True,
                generator=g_val,
                # Evaluate every sample: drop_last=True silently discarded up
                # to batch_size - 1 validation samples and left clients with a
                # small test shard with an empty loader.
                drop_last=False,
            )
        )

    g_test = torch.Generator().manual_seed(rand_seed + 2 * num_clients + 1)
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=True,
        generator=g_test,
        num_workers=2,
    )

    print("Data partitioning complete.")
    for i, dist in enumerate(client_class_distributions):
        print(f"Client {i} class distribution:")
        for cls in range(num_classes):
            print(f"  Class {cls}: {dist.get(cls, 0):.2f}")

    return (
        client_train_loaders,
        client_val_loaders,
        test_loader,
        client_class_distributions,
    )
